# Generates the Local SGD error grid used for Figure 1a, varying both first- and second-order heterogeneity (via concept shift ζ and covariate shift τ).

import numpy as np
import matplotlib.pyplot as plt

def sample_spherical_cap(axis, theta, n_samples):
    """
    Sample n_samples unit vectors uniformly from the spherical cap of
    half-angle theta (radians) around 'axis'.

    Candidates are drawn uniformly on the unit sphere (normalized standard
    Gaussian vectors) and kept when their angle to 'axis' is at most theta
    (rejection sampling), so accepted vectors are uniform on the cap.

    Parameters
    ----------
    axis : array-like, shape (d,)
        Cap centre; any non-zero finite vector (normalized internally).
    theta : float
        Cap half-angle in radians. Values below 1e-6 are treated as a
        degenerate cap: every returned vector is the normalized axis.
    n_samples : int
        Number of vectors to return.

    Returns
    -------
    np.ndarray of shape (n_samples, d)
        One unit vector per row (shape is (0, d) when n_samples == 0).

    Raises
    ------
    ValueError
        If 'axis' has zero or non-finite norm. Normalizing such an axis gives
        NaNs that never pass the acceptance test, so the sampling loop would
        otherwise never terminate.

    Notes
    -----
    The expected number of candidates per accepted vector is the inverse of the
    cap's share of the sphere, which shrinks roughly like theta**(d - 1) for
    small theta; tiny (but >= 1e-6) angles or large d can be extremely slow.
    """
    axis = np.asarray(axis, dtype=float)
    d = axis.shape[0]
    norm = np.linalg.norm(axis)
    if not np.isfinite(norm) or norm == 0.0:
        raise ValueError("axis must be a non-zero, finite vector")
    axis = axis / norm
    if theta < 1e-6:
        return np.tile(axis, (n_samples, 1))

    cos_theta = np.cos(theta)
    samples = np.empty((n_samples, d))
    n_accepted = 0
    # One candidate per iteration keeps the RNG consumption identical to a
    # plain per-vector loop, so seeded experiments stay reproducible.
    while n_accepted < n_samples:
        v = np.random.normal(size=d)
        v /= np.linalg.norm(v)
        if np.dot(v, axis) >= cos_theta:
            samples[n_accepted] = v
            n_accepted += 1
    return samples

def run_local_sgd(mu_list, x_star_list, sigma_noise, M, K, R, step_sizes):
    """
    Runs Local SGD for each step size and returns the best final error and its eta.

    For every candidate step size the server model starts at x = 0 and runs R
    communication rounds. In each round, every client m starts from the current
    server model and takes K SGD steps on the squared loss 0.5 * (beta . x - y)**2,
    with fresh samples beta ~ N(mu_m, I) and y = beta . x_star_m + N(0, sigma_noise**2).
    The server then averages the client models.

    Parameters
    ----------
    mu_list : sequence of (d,) arrays
        Per-client covariate means.
    x_star_list : sequence of (d,) arrays
        Per-client ground-truth parameters.
    sigma_noise : float
        Standard deviation of the label noise.
    M : int
        Unused; the number of clients is len(mu_list). Kept so existing call
        sites keep working.
    K : int
        Local SGD steps per round.
    R : int
        Communication rounds.
    step_sizes : iterable of float
        Candidate learning rates.

    Returns
    -------
    (best_error, best_eta)
        best_error is the smallest ||x_R - mean(x_star_list)|| among step sizes
        whose final error is finite, and best_eta is the step size achieving it.
        If no step size yields a finite error (e.g. all diverge), returns
        (inf, None).

    Notes
    -----
    Client m's expected gradient is (I + mu_m mu_m^T)(x - x_star_m). The mean of
    x_star_list therefore minimizes the average population loss only when all
    clients share mu_m mu_m^T or share x_star_m.
    NOTE(review): confirm the unweighted mean is the intended reference point.
    """
    d = len(mu_list[0])
    target = np.mean(x_star_list, axis=0)  # loop-invariant reference point
    best_error = np.inf
    best_eta = None
    # Large step sizes can blow up to inf/nan. Those runs are never selected,
    # so silence the flood of overflow/invalid-value warnings they would emit.
    with np.errstate(over='ignore', invalid='ignore'):
        for eta in step_sizes:
            x = np.zeros(d)
            for _ in range(R):
                x_locals = []
                for mu, x_star in zip(mu_list, x_star_list):
                    x_local = x.copy()
                    for _ in range(K):
                        # Same distribution as np.random.multivariate_normal(mu, np.eye(d)),
                        # without an SVD + PSD check of the identity on every draw.
                        beta = mu + np.random.standard_normal(d)
                        y = beta.dot(x_star) + np.random.normal(scale=sigma_noise)
                        grad = (x_local.dot(beta) - y) * beta
                        x_local -= eta * grad
                    x_locals.append(x_local)
                x = np.mean(x_locals, axis=0)
            error = np.linalg.norm(x - target)
            if np.isfinite(error) and error < best_error:
                best_error = error
                best_eta = eta
    return best_error, best_eta

def _cell_extent(grid):
    """
    Return (low, high) image edges that centre one heatmap cell on each grid value.

    Assumes 'grid' is evenly spaced (as produced by np.linspace). A single-point
    grid gets an arbitrary cell width of 1 so the image is not degenerate.
    """
    half = 0.5 * (grid[1] - grid[0]) if len(grid) > 1 else 0.5
    return grid[0] - half, grid[-1] + half


def _plot_error_heatmap(tau_list, zeta_list, results_avg):
    """
    Show results_avg (rows: zeta values, columns: tau values) as a heatmap.

    imshow's 'extent' sets the outer edges of the image, so each grid is padded
    by half a cell on both sides. Without this padding, every cell would be
    shifted by half a step from the tau/zeta value it represents.
    """
    tau_lo, tau_hi = _cell_extent(tau_list)
    zeta_lo, zeta_hi = _cell_extent(zeta_list)
    plt.figure(figsize=(7, 5))
    im = plt.imshow(results_avg, origin='lower',
                    extent=[tau_lo, tau_hi, zeta_lo, zeta_hi],
                    aspect='auto')
    plt.colorbar(im, label='Avg Best Error')
    plt.xlabel('Covariate shift')
    plt.ylabel('Concept shift')
    plt.title('Local SGD Avg Best Error')
    plt.show()


def experiment(d=5, M=50, K=10, R=10, sigma_noise=0.1,
               mu0=5.0, R_star=1.0,
               tau_points=21, zeta_points=11,
               step_sizes=None, n_runs=20, seed=17):
    """
    Run Local SGD over a (concept shift zeta) x (covariate shift tau) grid and plot the errors.

    - Covariate shift: client covariate means have norm mu0 and lie in a cap
      around a random central axis. tau in [0, 2*mu0] (tau_points values)
      bounds each mean's distance from mu0 * central_axis.
    - Concept shift: client optima have norm R_star and lie in a cap around the
      same axis. zeta in [0, 2*R_star] (zeta_points values) bounds each
      optimum's distance from R_star * central_axis.
    - Each grid cell averages run_local_sgd's best-over-step-size error over
      n_runs repetitions to reduce noise.

    Parameters
    ----------
    d, M, K, R, sigma_noise : dimension, number of clients, local steps per
        round, communication rounds, and label-noise std passed to run_local_sgd.
    mu0, R_star : norms of the covariate means and of the client optima.
    tau_points, zeta_points : number of grid values along tau and zeta.
    step_sizes : candidate step sizes; defaults to np.logspace(-3, -1, 5).
    n_runs : repetitions averaged per grid cell.
    seed : seed for NumPy's global RNG, set once at the start.

    Returns
    -------
    (tau_list, zeta_list, results_avg, eta_avg)
        results_avg and eta_avg have shape (zeta_points, tau_points). eta_avg
        averages the selected step size over runs that produced a finite error;
        it is nan for a cell where no run did.
    """
    np.random.seed(seed)
    # generate grids
    tau_list = np.linspace(0, 2 * mu0, tau_points)
    zeta_list = np.linspace(0, 2 * R_star, zeta_points)
    step_sizes = np.logspace(-3, -1, 5) if step_sizes is None else step_sizes

    print(f"Experiment: d={d}, M={M}, K={K}, R={R}, noise={sigma_noise}")
    print(f"μ₀={mu0}, τ grid size={tau_points}, ζ grid size={zeta_points}, runs={n_runs}\n")

    results_avg = np.zeros((len(zeta_list), len(tau_list)))
    eta_avg = np.zeros((len(zeta_list), len(tau_list)))
    central_axis = np.random.randn(d)
    central_axis /= np.linalg.norm(central_axis)

    for i, zeta in enumerate(zeta_list):
        # Central angle whose chord on the radius-R_star sphere has length zeta,
        # so every x_star lies within distance zeta of R_star * central_axis.
        phi = 2 * np.arcsin(min(zeta / (2 * R_star), 1.0))
        x_dirs = sample_spherical_cap(central_axis, phi, M)
        # The client optima are drawn once per zeta and shared by every tau and
        # every repetition below; only the covariate means and SGD noise are re-drawn.
        x_star_list = [R_star * v for v in x_dirs]
        print(f"[ζ {i+1}/{len(zeta_list)}] ζ={zeta:.2f}")

        for j, tau in enumerate(tau_list):
            # Same chord construction for the covariate means (sphere radius mu0).
            theta = 2 * np.arcsin(min(tau / (2 * mu0), 1.0))
            errs, etas = [], []
            for _ in range(n_runs):
                mu_dirs = sample_spherical_cap(central_axis, theta, M)
                mu_list = [mu0 * u for u in mu_dirs]
                err, eta_sel = run_local_sgd(mu_list, x_star_list, sigma_noise, M, K, R, step_sizes)
                errs.append(err)
                etas.append(eta_sel)
            avg_err = np.mean(errs)
            # run_local_sgd reports eta=None when no step size gave a finite error.
            # Such runs have no selected step size, so average over the others.
            eta_values = np.array(etas, dtype=float)  # None -> nan
            finite_etas = eta_values[np.isfinite(eta_values)]
            avg_eta = finite_etas.mean() if finite_etas.size else np.nan
            results_avg[i, j] = avg_err
            eta_avg[i, j] = avg_eta
            print(f"  τ={tau:.2f} → err={avg_err:.4f}, η={avg_eta:.2e}")

    _plot_error_heatmap(tau_list, zeta_list, results_avg)
    return tau_list, zeta_list, results_avg, eta_avg

# Run the full (tau x zeta) grid sweep only when executed as a script/notebook,
# so importing this module for its helpers does not start the long experiment.
if __name__ == "__main__":
    experiment()
